import colorsys
import os
import time

import numpy as np


from PIL import Image, ImageDraw, ImageFont
import jittor as jt
from dectect.model.frcnn import FasterRCNN
from dectect.dect_utils import (resize_image, cvtColor,
                              preprocess_input, show_config)
from dectect.dect_utils import DecodeBox

def get_new_img_size(height, width, img_min_side=600):
    if width <= height:
        f = float(img_min_side) / width
        resized_height = int(f * height)
        resized_width = int(img_min_side)
    else:
        f = float(img_min_side) / height
        resized_width = int(f * width)
        resized_height = int(img_min_side)

    return resized_height, resized_width

# --------------------------------------------#
#   使用自己训练好的模型预测需要修改2个参数
#   model_path和classes_path都需要修改！
#   如果出现shape不匹配
#   一定要注意训练时的NUM_CLASSES、
#   model_path和classes_path参数的修改
# --------------------------------------------#
class FRCNN(object):
    _defaults = {
        # --------------------------------------------------------------------------#
        #   使用自己训练好的模型进行预测一定要修改model_path和classes_path！
        #   model_path指向logs文件夹下的权值文件，classes_path指向model_data下的txt
        #
        #   训练好后logs文件夹下存在多个权值文件，选择验证集损失较低的即可。
        #   验证集损失较低不代表mAP较高，仅代表该权值在验证集上泛化性能较好。
        #   如果出现shape不匹配，同时要注意训练时的model_path和classes_path参数的修改
        # --------------------------------------------------------------------------#
        "model_path": 'D:\study\\track\\faster-rcnn\\faster-rcnn-pytorch-master\logs\ep100-loss0.791-val_loss1.085.pth',
        "classes_path": 'model_data/voc_classes.txt',
        # ---------------------------------------------------------------------#
        #   网络的主干特征提取网络，resnet50或者vgg
        # ---------------------------------------------------------------------#
        "backbone": "resnet50",
        # ---------------------------------------------------------------------#
        #   只有得分大于置信度的预测框会被保留下来
        # ---------------------------------------------------------------------#
        "confidence": 0.3,
        # ---------------------------------------------------------------------#
        #   非极大抑制所用到的nms_iou大小
        # ---------------------------------------------------------------------#
        "nms_iou": 0.3,
        # ---------------------------------------------------------------------#
        #   用于指定先验框的大小
        # ---------------------------------------------------------------------#
        'anchors_size': [8, 16, 32],
        # -------------------------------#
        #   是否使用Cuda
        #   没有GPU可以设置成False
        # -------------------------------#
        "cuda": True,
    }

    @classmethod
    def get_defaults(cls, n):
        if n in cls._defaults:
            return cls._defaults[n]
        else:
            return "Unrecognized attribute name '" + n + "'"

    # ---------------------------------------------------#
    #   初始化faster RCNN
    # ---------------------------------------------------#
    def __init__(self, **kwargs):
        self.__dict__.update(self._defaults)
        for name, value in kwargs.items():
            setattr(self, name, value)
            self._defaults[name] = value
            # ---------------------------------------------------#
        #   获得种类和先验框的数量
        # ---------------------------------------------------#
        if self.mode == 'rgb':
            self.class_names, self.num_classes = ['uav','uavnight'], 2
            self.model_path = '/data01/xjy/code/anti_cp/model_1/vis.pth'
        else:
            self.class_names, self.num_classes = ['UAV'], 1
            self.model_path = '/data01/xjy/code/anti_cp/model_1/ir.pth'

        self.std = jt.Var([0.1, 0.1, 0.2, 0.2]).repeat(self.num_classes + 1)[None]
        # if self.cuda:
        #     self.std = self.std.cuda()
        self.bbox_util = DecodeBox(self.std, self.num_classes)

        # ---------------------------------------------------#
        #   画框设置不同的颜色
        # ---------------------------------------------------#
        hsv_tuples = [(x / self.num_classes, 1., 1.) for x in range(self.num_classes)]
        self.colors = list(map(lambda x: colorsys.hsv_to_rgb(*x), hsv_tuples))
        self.colors = list(map(lambda x: (int(x[0] * 255), int(x[1] * 255), int(x[2] * 255)), self.colors))
        self.generate()

        show_config(**self._defaults)

    # ---------------------------------------------------#
    #   载入模型
    # ---------------------------------------------------#
    def generate(self):
        # -------------------------------#
        #   载入模型与权值
        # -------------------------------#
        self.net = FasterRCNN(self.num_classes, "predict", anchor_scales=self.anchors_size, backbone=self.backbone)
        # state_dict = jt.load(self.model_path)
        # jittor_state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
       # self.net.load(self.model_path)
        '''
        self.net = FasterRCNN(self.num_classes, "predict", anchor_scales=self.anchors_size, backbone=self.backbone)
        self.net.load_state_dict(jt.load(self.model_path))
        '''
        self.net = self.net.eval()
        print('{} model, anchors, and classes loaded.'.format(self.model_path))


    # ---------------------------------------------------#
    #   检测图片
    # ---------------------------------------------------#
    def detect_image(self, image, crop=False, count=False):
        # ---------------------------------------------------#
        #   计算输入图片的高和宽
        # ---------------------------------------------------#
        image_shape = np.array(np.shape(image)[0:2])
        # ---------------------------------------------------#
        #   计算resize后的图片的大小，resize后的图片短边为600
        # ---------------------------------------------------#
        input_shape = get_new_img_size(image_shape[0], image_shape[1])
        # ---------------------------------------------------------#
        #   在这里将图像转换成RGB图像，防止灰度图在预测时报错。
        #   代码仅仅支持RGB图像的预测，所有其它类型的图像都会转化成RGB
        # ---------------------------------------------------------#
        image = cvtColor(image)
        # ---------------------------------------------------------#
        #   给原图像进行resize，resize到短边为600的大小上
        # ---------------------------------------------------------#
        image_data = resize_image(image, [input_shape[1], input_shape[0]])
        # ---------------------------------------------------------#
        #   添加上batch_size维度
        # ---------------------------------------------------------#
        image_data = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, dtype='float32')), (2, 0, 1)), 0)


        images = jt.var(image_data)


        # -------------------------------------------------------------#
        #   roi_cls_locs  建议框的调整参数
        #   roi_scores    建议框的种类得分
        #   rois          建议框的坐标
        # -------------------------------------------------------------#
        roi_cls_locs, roi_scores, rois, _ = self.net(images)
        # -------------------------------------------------------------#
        #   利用classifier的预测结果对建议框进行解码，获得预测框
        # -------------------------------------------------------------#
        results = self.bbox_util.forward(roi_cls_locs, roi_scores, rois, image_shape, input_shape,
                                            nms_iou=self.nms_iou, confidence=self.confidence)
        # ---------------------------------------------------------#
        #   如果没有检测出物体，返回原图
        # ---------------------------------------------------------#
        if len(results[0]) <= 0:
            return [0], [0,0,0,0]

        top_label = np.array(results[0][:, 5], dtype='int32')
        top_conf = results[0][:, 4]
        top_boxes = results[0][:, :4]

        # ---------------------------------------------------------#
        #   设置字体与边框厚度
        # ---------------------------------------------------------#
        font = ImageFont.truetype(font='model_data/simhei.ttf',
                                  size=np.floor(3e-2 * image.size[1] + 0.5).astype('int32'))
        thickness = int(max((image.size[0] + image.size[1]) // np.mean(input_shape), 1))
        # ---------------------------------------------------------#
        #   计数
        # ---------------------------------------------------------#
        if count:
            print("top_label:", top_label)
            classes_nums = np.zeros([self.num_classes])
            for i in range(self.num_classes):
                num = np.sum(top_label == i)
                if num > 0:
                    print(self.class_names[i], " : ", num)
                classes_nums[i] = num
            print("classes_nums:", classes_nums)

        return top_conf, top_boxes


    def get_FPS(self, image, test_interval):
        # ---------------------------------------------------#
        #   计算输入图片的高和宽
        # ---------------------------------------------------#
        image_shape = np.array(np.shape(image)[0:2])
        input_shape = get_new_img_size(image_shape[0], image_shape[1])
        # ---------------------------------------------------------#
        #   在这里将图像转换成RGB图像，防止灰度图在预测时报错。
        #   代码仅仅支持RGB图像的预测，所有其它类型的图像都会转化成RGB
        # ---------------------------------------------------------#
        image = cvtColor(image)

        # ---------------------------------------------------------#
        #   给原图像进行resize，resize到短边为600的大小上
        # ---------------------------------------------------------#
        image_data = resize_image(image, [input_shape[1], input_shape[0]])
        # ---------------------------------------------------------#
        #   添加上batch_size维度
        # ---------------------------------------------------------#
        image_data = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, dtype='float32')), (2, 0, 1)), 0)


        images =jt.var(image_data)


        roi_cls_locs, roi_scores, rois, _ = self.net(images)
        # -------------------------------------------------------------#
        #   利用classifier的预测结果对建议框进行解码，获得预测框
        # -------------------------------------------------------------#
        results = self.bbox_util.forward(roi_cls_locs, roi_scores, rois, image_shape, input_shape,
                                            nms_iou=self.nms_iou, confidence=self.confidence)
        t1 = time.time()
        for _ in range(test_interval):
                roi_cls_locs, roi_scores, rois, _ = self.net(images)
                # -------------------------------------------------------------#
                #   利用classifier的预测结果对建议框进行解码，获得预测框
                # -------------------------------------------------------------#
                results = self.bbox_util.forward(roi_cls_locs, roi_scores, rois, image_shape, input_shape,
                                                 nms_iou=self.nms_iou, confidence=self.confidence)

        t2 = time.time()
        tact_time = (t2 - t1) / test_interval
        return tact_time

    # ---------------------------------------------------#
    #   检测图片
    # ---------------------------------------------------#
    def get_map_txt(self, image_id, image, class_names, map_out_path):
        f = open(os.path.join(map_out_path, "detection-results/" + image_id + ".txt"), "w")
        # ---------------------------------------------------#
        #   计算输入图片的高和宽
        # ---------------------------------------------------#
        image_shape = np.array(np.shape(image)[0:2])
        input_shape = get_new_img_size(image_shape[0], image_shape[1])
        # ---------------------------------------------------------#
        #   在这里将图像转换成RGB图像，防止灰度图在预测时报错。
        #   代码仅仅支持RGB图像的预测，所有其它类型的图像都会转化成RGB
        # ---------------------------------------------------------#
        image = cvtColor(image)

        # ---------------------------------------------------------#
        #   给原图像进行resize，resize到短边为600的大小上
        # ---------------------------------------------------------#
        image_data = resize_image(image, [input_shape[1], input_shape[0]])
        # ---------------------------------------------------------#
        #   添加上batch_size维度
        # ---------------------------------------------------------#
        image_data = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, dtype='float32')), (2, 0, 1)), 0)

        images =jt.var(image_data)


        roi_cls_locs, roi_scores, rois, _ = self.net(images)
        # -------------------------------------------------------------#
        #   利用classifier的预测结果对建议框进行解码，获得预测框
        # -------------------------------------------------------------#
        results = self.bbox_util.forward(roi_cls_locs, roi_scores, rois, image_shape, input_shape,
                                            nms_iou=self.nms_iou, confidence=self.confidence)
        # --------------------------------------#
            #   如果没有检测到物体，则返回原图
            # --------------------------------------#
        if len(results[0]) <= 0:
            return [0], [0,0,0,0]

        top_label = np.array(results[0][:, 5], dtype='int32')
        top_conf = results[0][:, 4]
        top_boxes = results[0][:, :4]

        for i, c in list(enumerate(top_label)):
            predicted_class = self.class_names[int(c)]
            box = top_boxes[i]
            score = str(top_conf[i])

            top, left, bottom, right = box
            if predicted_class not in class_names:
                continue

            f.write("%s %s %s %s %s %s\n" % (
            predicted_class, score[:6], str(int(left)), str(int(top)), str(int(right)), str(int(bottom))))

        f.close()
        return
